import re

try:
    from collections.abc import Iterable
except ImportError:  # Python 2 fallback; the ABC aliases were removed from `collections` in 3.10
    from collections import Iterable

import numpy as np
import pandas as pd
from dateutil.parser import parse

from .utils.factor_treament import winsorize, zscore



class FactorHandle(object):
    """
    Assemble factor values, prices and industry groups in the shape alphalens expects.

    The factor spec may be:

    * ``str``            -- a single factor;
    * ``list``/``tuple`` -- equal-weighted average of the listed factors;
    * ``dict``           -- weighted average, ``{factor_name: weight}``.

    Typical use::

        fh = FactorHandle(factor_engine, price_engine)
        fh.factor = {'pe': 2, 'roe': 1}
        fh.date_range = ('2018-01-01', '2018-12-31')
        factor_data, price, group = fh.get_all()
    """

    # Factor names matching this pattern (any name starting with "p" followed by
    # at least one character; the original intent was pe-style ratio factors)
    # are replaced by their reciprocal before standardisation.
    _INVERSE_FACTOR_PATTERN = re.compile(r'p.+')
    # End date used when only a begin date is supplied.
    _OPEN_END_DATE = '20990101'

    def __init__(self, factor_engine, price_engine):
        """
        :param factor_engine: connection/engine for the factor database, passed
            to ``pandas.read_sql``; each table in it is one factor.
        :param price_engine: connection/engine for the Wind database (prices
            and industry classification), passed to ``pandas.read_sql``.
        """
        self.factor_engine = factor_engine
        self.price_engine = price_engine
        # sid -> industry name, and the list of sids forming the universe.
        self.group, self.sid = self._get_indname()
        # Price cache: get_all() reuses it while the date range is unchanged.
        self.last_date_range = None
        self.price = None
        # Set through the `factor` / `date_range` properties before get_all().
        self._factor = None
        self._date_range = None
        # Every table in the factor database is a factor; used for validation.
        # NOTE(review): the column name assumes the engine's current database
        # is `factordb` ("show tables" names the column "Tables_in_<db>").
        self.factor_list = pd.read_sql("show tables", factor_engine)['Tables_in_factordb'].tolist()

    @property
    def factor(self):
        """The factor spec: str, list/tuple of names, or dict of name -> weight."""
        return self._factor

    @factor.setter
    def factor(self, value):
        """
        Validate and store the factor spec.

        :raises ValueError: if ``value`` has the wrong type, is empty, contains a
            name not in ``factor_list``, or repeats a name.
        """
        if not isinstance(value, (str, list, tuple, dict)):
            raise ValueError('factor must be a str or dict or list or tuple!')
        if len(value) == 0:
            raise ValueError("factor cannot be empty!")
        # A str is one name; iterating a dict yields its keys.
        names = [value] if isinstance(value, str) else list(value)
        # The set intersection drops unknown names and collapses duplicates, so
        # a length mismatch means the spec is invalid. Tuples are validated the
        # same way as lists (they used to slip through unchecked).
        if len(names) != len(set(names) & set(self.factor_list)):
            # Message means: "please enter valid factor names!"
            raise ValueError("请输入正确的因子值！")
        self._factor = value

    @property
    def date_range(self):
        """Begin date (str), or a list/tuple of date strings ``(begin[, end])``."""
        return self._date_range

    @date_range.setter
    def date_range(self, value):
        """
        Store the date range; only the begin date is required.

        :raises ValueError: if ``value`` is not a str/list/tuple or is empty.
        """
        if not isinstance(value, (str, list, tuple)):
            raise ValueError('date must be str and the begindate is essential!')
        if len(value) == 0:
            raise ValueError('daterange connot be empty!')
        self._date_range = value

    @staticmethod
    def _date_range_handle(date_range):
        """
        Normalise a date range into a list of 'YYYYMMDD' strings ``[begin, end]``.

        A single date (a str, or a one-element list/tuple) is treated as the
        begin date, and the end is left open (``_OPEN_END_DATE``). Returns None
        for any other input type; the ``date_range`` setter rejects such values.
        """
        if isinstance(date_range, str):
            date_range = [date_range]
        elif not isinstance(date_range, Iterable):
            return None
        dates = [parse(d).strftime('%Y%m%d') for d in date_range]
        if len(dates) == 1:
            # Without an end date the SQL '{0[1]}' lookup would raise IndexError.
            dates.append(FactorHandle._OPEN_END_DATE)
        return dates

    @staticmethod
    def _factor_handle(factor):
        """
        Normalise the factor spec for get_all().

        Returns a list of names for str/tuple/list input; a dict is returned
        unchanged (name -> weight).
        """
        if isinstance(factor, str):
            return [factor]
        if isinstance(factor, tuple):
            return list(factor)
        return factor

    @staticmethod
    def _reciprocal(values):
        """
        Return ``1 / values`` as floats, mapping division-by-zero results (±inf) to NaN.

        Float division avoids numpy's ValueError for integer ** negative power.
        Replacing inf with NaN keeps one zero value from turning the whole
        cross-section's mean/std (and thus its z-scores) into inf/NaN.
        """
        inverted = 1.0 / values.astype(float)
        return inverted.replace([np.inf, -np.inf], np.nan)

    def _get_price(self, date_range):
        """
        Load adjusted close prices as a (trade_dt x sid) DataFrame.

        :param date_range: normalised ``[begin, end]`` 'YYYYMMDD' strings (see
            ``_date_range_handle``); both ends inclusive.
        :return: DataFrame indexed by datetime ``trade_dt`` with one column per
            sid in ``self.sid``; sids with no rows become all-NaN columns.
        """
        # The dates are interpolated directly. This is safe because
        # _date_range_handle re-formats them with strftime('%Y%m%d').
        sql = """
            select S_INFO_WINDCODE sid, TRADE_DT trade_dt, S_DQ_ADJCLOSE price
            from nwind.dbo.AShareEODPrices
            where TRADE_DT >= '{0[0]}'
            and TRADE_DT <= '{0[1]}'
            """.format(date_range)
        # Missing prices stay NaN; alphalens handles them.
        price = pd.read_sql(sql, self.price_engine)
        price['trade_dt'] = pd.to_datetime(price['trade_dt'])
        price = price.pivot_table(values='price', columns='sid', index='trade_dt')
        price = price.reindex(columns=self.sid)
        return price

    def _get_onefactor(self, date_range, factor):
        """
        Load a single factor, restricted to the universe and standardised per date.

        :param date_range: normalised ``[begin, end]`` 'YYYYMMDD' strings.
        :param factor: one factor (table) name. It is interpolated into the SQL,
            so it must already be validated against ``factor_list`` (the
            ``factor`` setter does this).
        :return: DataFrame indexed by (trade_dt, sid) with a ``value`` column,
            after per-date ``winsorize`` then ``zscore``.
        """
        sql = """
            select sid, trade_dt, value from factordb.{0}
            where trade_dt >= '{1[0]}'
            and trade_dt <= '{1[1]}'
            """.format(factor, date_range)
        factor_data = pd.read_sql(sql, self.factor_engine)
        # Invert pe-style ratio factors (see _INVERSE_FACTOR_PATTERN).
        if self._INVERSE_FACTOR_PATTERN.match(factor) is not None:
            factor_data['value'] = self._reciprocal(factor_data['value'])
        factor_data['trade_dt'] = pd.to_datetime(factor_data['trade_dt'])
        factor_data = factor_data[factor_data['sid'].isin(self.sid)]
        # Not inplace: factor_data is a filtered slice of the query result.
        factor_data = factor_data.set_index(['trade_dt', 'sid'])
        # Cross-sectional winsorize, then z-score, per trade date.
        # group_keys=False keeps the (trade_dt, sid) index; pandas >= 2.0
        # would otherwise prepend an extra trade_dt level to apply's result.
        factor_data = factor_data.groupby(level='trade_dt', group_keys=False).apply(winsorize)
        factor_data = factor_data.groupby(level='trade_dt', group_keys=False).apply(zscore)
        return factor_data

    def _get_indname(self):
        """
        Load each stock's industry name from the SW industry classification.

        Only rows with CUR_SIGN = '1' are used. Industry names come from
        LEVELNUM = '2' codes and are matched on the first 4 code characters.

        :return: ``(group, sid)`` -- dict of sid -> industry name, and the list
            of its keys (the stock universe).
        """
        sql = """
        select a.S_INFO_WINDCODE sid, b.INDUSTRIESNAME indname
        FROM nwind.dbo.AShareSWIndustriesClass as a, nWind.dbo.ASHAREINDUSTRIESCODE as b
        where a.CUR_SIGN = '1'
        and b.USED = '1'
        and b.LEVELNUM = '2'
        and LEFT(a.SW_IND_CODE, 4) = LEFT(b.INDUSTRIESCODE, 4)
        """
        group = pd.read_sql(sql, self.price_engine)
        group.set_index('sid', inplace=True)
        group = group.to_dict()['indname']
        sid = list(group.keys())
        return group, sid

    def get_all(self):
        """
        Build the inputs for alphalens from the current ``factor`` and ``date_range``.

        Multiple factors are combined by (weighted) averaging after alignment on
        (trade_dt, sid). A (date, stock) missing from any one factor is NaN in
        the result.

        :return: ``(factor_data, price, group)`` -- the combined factor values
            indexed by (trade_dt, sid); the price DataFrame; and the dict of
            sid -> industry name.
        :raises ValueError: if ``factor`` or ``date_range`` has not been set, or
            if dict weights sum to zero.
        """
        if getattr(self, '_factor', None) is None or getattr(self, '_date_range', None) is None:
            raise ValueError('factor and date_range must be set before calling get_all()')
        date_range = self._date_range_handle(self.date_range)
        # Re-query prices only when the date range changed since the last fetch.
        if self.price is None or self.last_date_range != date_range:
            self.price = self._get_price(date_range)
            self.last_date_range = date_range
        price = self.price
        factor = self._factor_handle(self.factor)
        if isinstance(factor, dict):
            total_weight = sum(factor.values())
            if total_weight == 0:
                raise ValueError('factor weights must not sum to zero!')
            factor_data = sum(self._get_onefactor(date_range, f) * w
                              for f, w in factor.items()) / total_weight
        else:
            # A list of names (str and tuple specs are normalised to a list).
            factor_data = sum(self._get_onefactor(date_range, f)
                              for f in factor) / len(factor)
        return factor_data, price, self.group

